{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 将你的 TVMScript 代码包装为 PyTorch 模块\n",
    "\n",
    "**作者**：[Yaoda Zhou](https://github.com/juda)\n",
    "\n",
    "本文是关于如何将 TVMScript 代码包装为 PyTorch 模块的教程。\n",
    "使用装饰器 `as_torch`，用户可以自然地将 TVMScript 代码包装成 PyTorch {class}`torch.~nn.Module`。\n",
    "\n",
    "要跟随本教程，需要安装 PyTorch。\n",
    "\n",
    "```bash\n",
    "%%shell\n",
    "pip install torch\n",
    "```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import set_env"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/media/pc/data/lxw/ai/tvm/python/tvm/contrib/torch/__init__.py:50: RuntimeWarning: The library libpt_tvmdsoop is not built successfully. /media/pc/data/lxw/ai/tvm/build/libpt_tvmdsoop.so: cannot open shared object file: No such file or directory\n",
      "  warnings.warn(\n",
      "/media/pc/data/lxw/ai/tvm/python/tvm/contrib/torch/__init__.py:50: RuntimeWarning: The library libpt_tvmdsoop_new is not built successfully. /media/pc/data/lxw/ai/tvm/build/libpt_tvmdsoop_new.so: cannot open shared object file: No such file or directory\n",
      "  warnings.warn(\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "import torch.nn.functional as F\n",
    "import torch.utils.benchmark as benchmark\n",
    "\n",
    "import tvm\n",
    "from tvm.contrib.torch import as_torch\n",
    "from tvm.script import tir as T"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 使用 TVMScript 编写自己的 PyTorch 算子\n",
    "\n",
    "PyTorch 是非常流行的机器学习框架，其中包含了大多数常用算子的优化实现。尽管如此，有时你可能想在 PyTorch 中编写自己的算子。在这种情况下，这些自定义算子的性能可能无法满足你的需求。\n",
    "\n",
    "例如，假设我们要定义 1-d 深度卷积算子，输入通道数和输出通道数都是 70，宽度是 80，卷积核大小是 20，那么1-d深度卷积可以在 PyTorch 中用一行代码来表示："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "in_channel = 70\n",
    "out_channel = 70\n",
    "width = 80\n",
    "kernel_size = 20\n",
    "\n",
    "\n",
    "def torch_depthwise(inputs, filters):\n",
    "    return F.conv1d(inputs, filters.view(out_channel, 1, kernel_size), groups=out_channel)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "可以这样运行函数："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "inputs = torch.randn(in_channel, width)\n",
    "filters = torch.randn(out_channel, kernel_size)\n",
    "ret_torch = torch_depthwise(inputs, filters)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "在普通的 Python 代码中，`torch_depthwise` 函数可以写成："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "def vanilla_depthwise(input, weight):\n",
    "    ret = torch.zeros(out_channel, width - kernel_size + 1)\n",
    "    for j in range(out_channel):\n",
    "        for i in range(width - kernel_size + 1):\n",
    "            for k in range(kernel_size):\n",
    "                ret[j, i] += weight[j, k] * input[j, i + k]\n",
    "    return ret"
   ]
  },
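  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a worked check of the loop nest above (a sketch; the symbols $A$ for the input, $B$ for the filter, $C$ for the output, $W$ for the width, and $K$ for the kernel size are illustrative labels), each output element is a sum over the kernel window:\n",
    "\n",
    "$$\n",
    "C[j, i] = \\sum_{k=0}^{K-1} B[j, k] A[j, i + k], \\qquad 0 \\le j < 70, \\quad 0 \\le i < W - K + 1\n",
    "$$\n",
    "\n",
    "With $W = 80$ and $K = 20$, the output width is $80 - 20 + 1 = 61$, so the output has shape $(70, 61)$. Counting one multiply and one add per inner iteration gives $70 \\times 61 \\times 20 \\times 2 = 170800$ floating-point operations."
   ]
  },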
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "然后，计划利用 TVM 的强大功能来优化 `depthwise` 函数。TVM 社区提出了一种嵌入在 Python 中的特定领域语言，称为 TVMScript，它作为 TVM 的 Tensor IR 的高级前端。\n",
    "\n",
    "上面的深度卷积 1D 代码可以按照如下方式转换为 TVMScript。我们提供了 `as_torch` 装饰器，它会自动将 TVMScript 代码转换为 PyTorch 的 `nn.Module`。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "@as_torch\n",
    "@T.prim_func\n",
    "def tvm_depthwise(\n",
    "    A: T.Buffer((70, 80), \"float32\"),\n",
    "    B: T.Buffer((70, 20), \"float32\"),\n",
    "    C: T.Buffer((70, 61), \"float32\"),\n",
    ") -> None:\n",
    "    for j, i, k in T.grid(70, 61, 20):\n",
    "        with T.block():\n",
    "            vi, vj, vk = T.axis.remap(\"SSR\", [i, j, k])\n",
    "            with T.init():\n",
    "                C[vj, vi] = T.float32(0)\n",
    "            C[vj, vi] += B[vj, vk] * A[vj, vi + vk]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "我们可以通过调用默认设置下的 `tune` 方法来构建 TVMScript 代码。如果不提供额外信息，模型将会针对 CPU 进行优化。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2024-03-20 12:10:34 [INFO] Logging directory: /tmp/tmphj33434s/logs\n",
      "2024-03-20 12:10:48 [INFO] LocalBuilder: max_workers = 24\n",
      "2024-03-20 12:10:50 [INFO] LocalRunner: max_workers = 1\n",
      "2024-03-20 12:10:51 [INFO] [task_scheduler.cc:159] Initializing Task #0: \"main\"\n",
      "2024-03-20 12:10:51 [INFO] [task_scheduler.cc:180] TaskScheduler picks Task #0: \"main\"\n",
      "2024-03-20 12:10:51 [INFO] [task_scheduler.cc:193] Sending 32 sample(s) to builder\n",
      "2024-03-20 12:10:56 [INFO] [task_scheduler.cc:195] Sending 32 sample(s) to runner\n",
      "2024-03-20 12:11:04 [DEBUG] XGB iter   0: tr-p-rmse: 0.394348\ttr-a-peak@32: 0.999438\ttr-rmse: 0.394544\ttr-rmse: 0.394544\n",
      "2024-03-20 12:11:04 [DEBUG] XGB iter  25: tr-p-rmse: 0.013164\ttr-a-peak@32: 0.999686\ttr-rmse: 0.012791\ttr-rmse: 0.012791\n",
      "2024-03-20 12:11:04 [DEBUG] XGB iter  50: tr-p-rmse: 0.013063\ttr-a-peak@32: 0.999686\ttr-rmse: 0.012660\ttr-rmse: 0.012660\n",
      "2024-03-20 12:11:04 [DEBUG] XGB iter  75: tr-p-rmse: 0.013063\ttr-a-peak@32: 0.999686\ttr-rmse: 0.012660\ttr-rmse: 0.012660\n",
      "2024-03-20 12:11:04 [DEBUG] XGB stopped. Best iteration: [34] tr-p-rmse:0.01306\ttr-a-peak@32:0.99969\ttr-rmse:0.01266\ttr-rmse:0.01266 \n",
      "2024-03-20 12:11:04 [INFO] [task_scheduler.cc:237] [Updated] Task #0: \"main\"\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>Name</th>\n",
       "      <th>FLOP</th>\n",
       "      <th>Weight</th>\n",
       "      <th>Speed (GFLOPS)</th>\n",
       "      <th>Latency (us)</th>\n",
       "      <th>Weighted Latency (us)</th>\n",
       "      <th>Trials</th>\n",
       "      <th>Done</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>main</td>\n",
       "      <td>170800</td>\n",
       "      <td>1</td>\n",
       "      <td>13.1876</td>\n",
       "      <td>12.9516</td>\n",
       "      <td>12.9516</td>\n",
       "      <td>32</td>\n",
       "      <td></td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "    Name      FLOP    Weight    Speed (GFLOPS)    Latency (us)   \\\n",
       "0   main    170800         1           13.1876         12.9516    \n",
       "\n",
       "    Weighted Latency (us)    Trials    Done   \n",
       "0                 12.9516        32           "
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Total trials: 32\n",
      "Total latency (us): 12.9516\n",
      "\n",
      "2024-03-20 12:11:04 [DEBUG] [task_scheduler.cc:318] \n",
      " ID | Name |   FLOP | Weight | Speed (GFLOPS) | Latency (us) | Weighted Latency (us) | Trials | Done \n",
      "-----------------------------------------------------------------------------------------------------\n",
      "  0 | main | 170800 |      1 |        13.1876 |      12.9516 |               12.9516 |     32 |      \n",
      "-----------------------------------------------------------------------------------------------------\n",
      "Total trials: 32\n",
      "Total latency (us): 12.9516\n",
      "\n",
      "2024-03-20 12:11:04 [INFO] [task_scheduler.cc:260] Task #0 has finished. Remaining task(s): 0\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>Name</th>\n",
       "      <th>FLOP</th>\n",
       "      <th>Weight</th>\n",
       "      <th>Speed (GFLOPS)</th>\n",
       "      <th>Latency (us)</th>\n",
       "      <th>Weighted Latency (us)</th>\n",
       "      <th>Trials</th>\n",
       "      <th>Done</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>main</td>\n",
       "      <td>170800</td>\n",
       "      <td>1</td>\n",
       "      <td>13.1876</td>\n",
       "      <td>12.9516</td>\n",
       "      <td>12.9516</td>\n",
       "      <td>32</td>\n",
       "      <td>Y</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "    Name      FLOP    Weight    Speed (GFLOPS)    Latency (us)   \\\n",
       "0   main    170800         1           13.1876         12.9516    \n",
       "\n",
       "    Weighted Latency (us)    Trials    Done   \n",
       "0                 12.9516        32       Y   "
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2024-03-20 12:11:04 [DEBUG] [task_scheduler.cc:318] \n",
      " ID | Name |   FLOP | Weight | Speed (GFLOPS) | Latency (us) | Weighted Latency (us) | Trials | Done \n",
      "-----------------------------------------------------------------------------------------------------\n",
      "  0 | main | 170800 |      1 |        13.1876 |      12.9516 |               12.9516 |     32 |    Y \n",
      "-----------------------------------------------------------------------------------------------------\n",
      "Total trials: 32\n",
      "Total latency (us): 12.9516\n",
      "\n",
      "\n",
      "Total trials: 32\n",
      "Total latency (us): 12.9516\n",
      "\n"
     ]
    },
    {
     "ename": "ValueError",
     "evalue": "as_torch requires the flag /\"USE_PT_TVMDSOOP/\" set in config.cmake",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mValueError\u001b[0m                                Traceback (most recent call last)",
      "Cell \u001b[0;32mIn[7], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m \u001b[43mtvm_depthwise\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtune\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[0;32m/media/pc/data/lxw/ai/tvm/python/tvm/contrib/torch/as_torch.py:107\u001b[0m, in \u001b[0;36mOperatorModuleWrapper.tune\u001b[0;34m(self, target, max_trials_global, num_trials_per_iter, builder, runner, database, cost_model, measure_callbacks, task_scheduler, space, strategy, num_tuning_cores, seed)\u001b[0m\n\u001b[1;32m    105\u001b[0m sch \u001b[38;5;241m=\u001b[39m ms\u001b[38;5;241m.\u001b[39mtir_integration\u001b[38;5;241m.\u001b[39mcompile_tir(database, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mir_module, target)\n\u001b[1;32m    106\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mir_module \u001b[38;5;241m=\u001b[39m sch\u001b[38;5;241m.\u001b[39mmod\n\u001b[0;32m--> 107\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mbuild\u001b[49m\u001b[43m(\u001b[49m\u001b[43mtarget\u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[0;32m/media/pc/data/lxw/ai/tvm/python/tvm/contrib/torch/as_torch.py:117\u001b[0m, in \u001b[0;36mOperatorModuleWrapper.build\u001b[0;34m(self, target)\u001b[0m\n\u001b[1;32m    114\u001b[0m func \u001b[38;5;241m=\u001b[39m tvm\u001b[38;5;241m.\u001b[39mget_global_func(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mtvmtorch.save_runtime_mod\u001b[39m\u001b[38;5;124m\"\u001b[39m, allow_missing\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m)\n\u001b[1;32m    116\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m func \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[0;32m--> 117\u001b[0m     \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mValueError\u001b[39;00m(\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mas_torch requires the flag /\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mUSE_PT_TVMDSOOP/\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m set in config.cmake\u001b[39m\u001b[38;5;124m'\u001b[39m)\n\u001b[1;32m    118\u001b[0m func(runtime_module)\n\u001b[1;32m    120\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mrt_module \u001b[38;5;241m=\u001b[39m torch\u001b[38;5;241m.\u001b[39mclasses\u001b[38;5;241m.\u001b[39mtvm_torch\u001b[38;5;241m.\u001b[39mOperatorModuleWrapper()\n",
      "\u001b[0;31mValueError\u001b[0m: as_torch requires the flag /\"USE_PT_TVMDSOOP/\" set in config.cmake"
     ]
    }
   ],
   "source": [
    "tvm_depthwise.tune()"
   ]
  },
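  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The build step at the end of `tune` depends on TVM's PyTorch integration library; when it is missing, `as_torch` raises a `ValueError` that names the `USE_PT_TVMDSOOP` flag in `config.cmake`. A minimal sketch of the corresponding setting, assuming TVM is built from source and that the configuration file lives at `build/config.cmake` (the location is an assumption about the local setup):\n",
    "\n",
    "```cmake\n",
    "# build/config.cmake (assumed location)\n",
    "# Enable the PyTorch integration that as_torch relies on.\n",
    "set(USE_PT_TVMDSOOP ON)\n",
    "```\n",
    "\n",
    "After changing the flag, TVM has to be rebuilt so that the integration library is produced."
   ]
  },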
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "可以打印出优化后的 TVMScript 代码，以查看程序是如何被转换的，如下"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(tvm_depthwise.script())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "可以验证这两个输出是相同的："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "ename": "ValueError",
     "evalue": "as_torch requires the flag /\"USE_PT_TVMDSOOP/\" set in config.cmake",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mValueError\u001b[0m                                Traceback (most recent call last)",
      "Cell \u001b[0;32mIn[8], line 2\u001b[0m\n\u001b[1;32m      1\u001b[0m ret_tvm \u001b[38;5;241m=\u001b[39m torch\u001b[38;5;241m.\u001b[39mzeros(out_channel, width \u001b[38;5;241m-\u001b[39m kernel_size \u001b[38;5;241m+\u001b[39m \u001b[38;5;241m1\u001b[39m)\n\u001b[0;32m----> 2\u001b[0m \u001b[43mtvm_depthwise\u001b[49m\u001b[43m(\u001b[49m\u001b[43minputs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mfilters\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mret_tvm\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m      4\u001b[0m testing\u001b[38;5;241m.\u001b[39massert_allclose(ret_torch\u001b[38;5;241m.\u001b[39mcpu()\u001b[38;5;241m.\u001b[39mnumpy(), ret_tvm\u001b[38;5;241m.\u001b[39mcpu()\u001b[38;5;241m.\u001b[39mnumpy(), atol\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m1e-5\u001b[39m, rtol\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m1e-5\u001b[39m)\n",
      "File \u001b[0;32m/media/pc/data/tmp/cache/conda/envs/xin/lib/python3.12/site-packages/torch/nn/modules/module.py:1511\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m   1509\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)  \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[1;32m   1510\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1511\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[0;32m/media/pc/data/tmp/cache/conda/envs/xin/lib/python3.12/site-packages/torch/nn/modules/module.py:1520\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m   1515\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m   1516\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m   1517\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m   1518\u001b[0m         \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m   1519\u001b[0m         \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1520\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m   1522\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m   1523\u001b[0m     result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n",
      "File \u001b[0;32m/media/pc/data/lxw/ai/tvm/python/tvm/contrib/torch/as_torch.py:127\u001b[0m, in \u001b[0;36mOperatorModuleWrapper.forward\u001b[0;34m(self, *torch_inputs)\u001b[0m\n\u001b[1;32m    125\u001b[0m     \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mbuild(target\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mcuda\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m    126\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m torch_inputs[\u001b[38;5;241m0\u001b[39m]\u001b[38;5;241m.\u001b[39mdevice\u001b[38;5;241m.\u001b[39mtype \u001b[38;5;241m==\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mcpu\u001b[39m\u001b[38;5;124m\"\u001b[39m:\n\u001b[0;32m--> 127\u001b[0m     \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mbuild\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    128\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m    129\u001b[0m     \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mException\u001b[39;00m(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mthe target \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mtorch_inputs[\u001b[38;5;241m0\u001b[39m]\u001b[38;5;241m.\u001b[39mdevice\u001b[38;5;241m.\u001b[39mtype\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m is not supported yet\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n",
      "File \u001b[0;32m/media/pc/data/lxw/ai/tvm/python/tvm/contrib/torch/as_torch.py:117\u001b[0m, in \u001b[0;36mOperatorModuleWrapper.build\u001b[0;34m(self, target)\u001b[0m\n\u001b[1;32m    114\u001b[0m func \u001b[38;5;241m=\u001b[39m tvm\u001b[38;5;241m.\u001b[39mget_global_func(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mtvmtorch.save_runtime_mod\u001b[39m\u001b[38;5;124m\"\u001b[39m, allow_missing\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m)\n\u001b[1;32m    116\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m func \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[0;32m--> 117\u001b[0m     \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mValueError\u001b[39;00m(\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mas_torch requires the flag /\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mUSE_PT_TVMDSOOP/\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m set in config.cmake\u001b[39m\u001b[38;5;124m'\u001b[39m)\n\u001b[1;32m    118\u001b[0m func(runtime_module)\n\u001b[1;32m    120\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mrt_module \u001b[38;5;241m=\u001b[39m torch\u001b[38;5;241m.\u001b[39mclasses\u001b[38;5;241m.\u001b[39mtvm_torch\u001b[38;5;241m.\u001b[39mOperatorModuleWrapper()\n",
      "\u001b[0;31mValueError\u001b[0m: as_torch requires the flag /\"USE_PT_TVMDSOOP/\" set in config.cmake"
     ]
    }
   ],
   "source": [
    "ret_tvm = torch.zeros(out_channel, width - kernel_size + 1)\n",
    "tvm_depthwise(inputs, filters, ret_tvm)\n",
    "\n",
    "testing.assert_allclose(ret_torch.cpu().numpy(), ret_tvm.cpu().numpy(), atol=1e-5, rtol=1e-5)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Benchmark"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "results = []\n",
    "for i in range(5):\n",
    "    inputs = torch.randn(out_channel, width)\n",
    "    filters = torch.randn(out_channel, kernel_size)\n",
    "    res = torch.zeros(out_channel, width - kernel_size + 1)\n",
    "    sub_label = f\"[test {i}]\"\n",
    "    results.append(\n",
    "        benchmark.Timer(\n",
    "            stmt=\"tvm_depthwise(inputs, filters, res)\",\n",
    "            setup=\"from __main__ import tvm_depthwise\",\n",
    "            globals={\"inputs\": inputs, \"filters\": filters, \"res\": res},\n",
    "            sub_label=sub_label,\n",
    "            description=\"TVMScript\",\n",
    "        ).blocked_autorange()\n",
    "    )\n",
    "    results.append(\n",
    "        benchmark.Timer(\n",
    "            stmt=\"torch_depthwise(inputs, filters)\",\n",
    "            setup=\"from __main__ import torch_depthwise\",\n",
    "            globals={\n",
    "                \"inputs\": inputs,\n",
    "                \"filters\": filters,\n",
    "            },\n",
    "            sub_label=sub_label,\n",
    "            description=\"PyTorch\",\n",
    "        ).blocked_autorange()\n",
    "    )\n",
    "compare = benchmark.Compare(results)\n",
    "compare.print()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "在作者的环境中，`tvm_depthwise` 的平均推理时间是120.0微秒，而 `torch_depthwise` 的平均推理时间是196.0微秒（PyTorch版本是1.11.0），显示出大约38%的速度提升。"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "xin",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
